import torch
from torch import nn

# Small MLP: the first layer declares its input width explicitly (20), while
# the two LazyLinear layers only declare their output width and leave the
# input width to be inferred on their first forward call.
net = nn.Sequential(
    nn.Linear(20, 256),
    nn.ReLU(),
    nn.LazyLinear(128),
    nn.ReLU(),
    nn.LazyLinear(10),
)

# Batch of 2 samples with 10 features each. This does not match the 20 input
# features declared by the first layer above.
# NOTE(review): the mismatch looks deliberate (the forward pass below is
# guarded and only reports the error) -- confirm it is a demonstration and
# not a typo for torch.rand(2, 20).
X = torch.rand(2, 10)

try:
    net(X)
except RuntimeError as e:
    # PyTorch raises RuntimeError for incompatible matmul shapes; catching
    # only that type lets unrelated failures propagate instead of being
    # printed and silently ignored.
    print(e)